"""
Phase 4b: KAOPEN Change Analysis
Development Threshold Paper

Tests whether the change in capital account openness (KAOPEN), not just its
level, predicts threshold crossing. Motivated by an event study showing that
crossers and non-crossers diverge over time.
"""

import pandas as pd
import numpy as np
from scipy import stats
import statsmodels.api as sm
import sys
sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')
from model import PanelGLS

# ═══════════════════════════════════════════════════════════════════════
# Load data
# ═══════════════════════════════════════════════════════════════════════
# Country-year panel, truncated at 2024.
full = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
panel = full.loc[full['year'].le(2024)].copy()

# One row per country with its threshold-crossing years and status.
cross_df = pd.read_csv('/mnt/c/demographics_capital_flows/development_threshold/data/crossing_data.csv')

# Bounds of the gdp_pc_ppp "zone" (used to filter the panel in section 5).
LOWER = 9000
UPPER = 25000

# Countries that entered the zone (have a 9k crossing year) and were not
# already above it for the whole sample.
entered_zone = cross_df['cross_9k'].notna()
not_always_above = cross_df['status'].ne('Always above')
zone_entrants = cross_df.loc[entered_zone & not_always_above].copy()

# ═══════════════════════════════════════════════════════════════════════
# 1. Compute ΔKAOPEN during transit for each country
# ═══════════════════════════════════════════════════════════════════════
print("=" * 70)
print("1. ΔKAOPEN DURING TRANSIT")
print("=" * 70)

# Pre-create every per-country output column. Later sections select these
# columns by name, so they must exist even if no country yields a value.
for _col in ['kaopen_entry_val', 'kaopen_exit_val', 'delta_kaopen',
             'delta_kaopen_10yr', 'kaopen_annual_change']:
    if _col not in zone_entrants.columns:
        zone_entrants[_col] = np.nan

for idx, row in zone_entrants.iterrows():
    iso = row['iso3']
    entry = row['cross_9k']
    exit_25k = row.get('cross_25k')  # None if the column is absent
    c = panel[(panel['iso3'] == iso) & (panel['kaopen'].notna())].sort_values('year')

    if len(c) == 0:
        continue

    # KAOPEN at entry (NaN when the entry year itself has no observation)
    at_entry = c[c['year'] == entry]
    kao_entry = at_entry['kaopen'].values[0] if len(at_entry) > 0 else np.nan

    # KAOPEN at exit (or latest). If the exit year has no observation, use the
    # last observation at or before it. The last observation overall may lie
    # years after the country left the zone and would stretch the window.
    exit_year = exit_25k if pd.notna(exit_25k) else c['year'].max()
    at_or_before_exit = c[c['year'] <= exit_year]
    kao_exit = at_or_before_exit['kaopen'].values[-1] if len(at_or_before_exit) > 0 else np.nan

    # Total change over the transit
    if pd.notna(kao_entry) and pd.notna(kao_exit):
        zone_entrants.loc[idx, 'kaopen_entry_val'] = kao_entry
        zone_entrants.loc[idx, 'kaopen_exit_val'] = kao_exit
        zone_entrants.loc[idx, 'delta_kaopen'] = kao_exit - kao_entry

    # KAOPEN change in first 10 years after entry (early opening)
    early = c[(c['year'] >= entry) & (c['year'] <= entry + 10)]
    if len(early) >= 2:
        kao_early_end = early.iloc[-1]['kaopen']
        zone_entrants.loc[idx, 'delta_kaopen_10yr'] = kao_early_end - kao_entry

    # Average annual change over the transit (entry .. exit, or entry .. latest)
    transit_data = c[c['year'] >= entry]
    if pd.notna(exit_25k):
        transit_data = transit_data[transit_data['year'] <= exit_25k]
    if len(transit_data) >= 2:
        # diff().mean() telescopes to (last - first) / (n_obs - 1), which is a
        # per-YEAR rate only when no years are missing. Dividing by the
        # calendar span gives the same value without gaps and the right one
        # with gaps.
        year_span = transit_data['year'].iloc[-1] - transit_data['year'].iloc[0]
        if year_span > 0:
            zone_entrants.loc[idx, 'kaopen_annual_change'] = (
                (transit_data['kaopen'].iloc[-1] - transit_data['kaopen'].iloc[0]) / year_span
            )

print(f"  ΔKAOPEN computed for {zone_entrants['delta_kaopen'].notna().sum()} "
      f"of {len(zone_entrants)} zone entrants")

# ═══════════════════════════════════════════════════════════════════════
# 2. Bivariate: ΔKAOPEN for crossers vs non-crossers
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2. BIVARIATE: ΔKAOPEN CROSSERS vs NON-CROSSERS")
print("=" * 70)

# (column, label) pairs compared between crossers (exited_above == 1) and
# non-crossers (exited_above == 0) with Welch's t-test and Mann-Whitney U.
bivariate_specs = (
    ('delta_kaopen', 'Total ΔKAOPEN'),
    ('delta_kaopen_10yr', 'ΔKAOPEN first 10 years'),
    ('kaopen_annual_change', 'Annual ΔKAOPEN'),
    ('kaopen_entry_val', 'KAOPEN level at entry'),
)
star_cutoffs = ((0.01, '***'), (0.05, '**'), (0.1, '*'))

for column, label in bivariate_specs:
    if column in zone_entrants.columns:
        crossers = zone_entrants.loc[zone_entrants['exited_above'] == 1, column].dropna()
        non_crossers = zone_entrants.loc[zone_entrants['exited_above'] == 0, column].dropna()
        # Both groups need at least three observations to be compared.
        if min(len(crossers), len(non_crossers)) >= 3:
            p_welch = stats.ttest_ind(crossers, non_crossers, equal_var=False).pvalue
            p_mw = stats.mannwhitneyu(crossers, non_crossers, alternative='two-sided').pvalue
            # Stars reflect the more significant of the two tests.
            best_p = min(p_welch, p_mw)
            stars = next((mark for cut, mark in star_cutoffs if best_p < cut), '')
            mean_c, mean_nc = crossers.mean(), non_crossers.mean()
            print(f"  {label:30s}: crossers={mean_c:+.3f}, non-crossers={mean_nc:+.3f}, "
                  f"diff={mean_c - mean_nc:+.3f}, t-p={p_welch:.4f}, W-p={p_mw:.4f} {stars}")

# ═══════════════════════════════════════════════════════════════════════
# 3. Logit: level vs change vs both
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("3. LOGIT: KAOPEN LEVEL vs CHANGE vs BOTH")
print("=" * 70)

base = zone_entrants[['exited_above', 'Z_1_entry', 'gdp_pc_ppp_entry']].copy()
LOGIT_BASE_REGRESSORS = ['Z_1_entry', 'gdp_pc_ppp_entry']
LOGIT_MIN_N = 15


def _fit_logit(header, extra_vars):
    """Fit and print a logit of exited_above on the base regressors plus *extra_vars*.

    The estimation sample is the complete-case subset of base + extra_vars,
    so models with different extra_vars may use different samples.

    Returns (sample, model). sample is None if any extra column is missing.
    model is None if a column is missing, the sample has fewer than
    LOGIT_MIN_N rows, or the fit hits a singular matrix.
    """
    missing = [v for v in extra_vars if v not in zone_entrants.columns]
    if missing:
        print(f"\n  {header}: skipped (missing columns {missing})")
        return None, None

    sample = base.join(zone_entrants[extra_vars]).dropna()
    if len(sample) < LOGIT_MIN_N:
        return sample, None

    regressors = LOGIT_BASE_REGRESSORS + list(extra_vars)
    try:
        model = sm.Logit(sample['exited_above'],
                         sm.add_constant(sample[regressors])).fit(disp=0)
    except np.linalg.LinAlgError as err:
        # Report the failure and let the remaining sections still run.
        print(f"\n  {header}: fit failed ({err})")
        return sample, None

    # prsquared is McFadden's pseudo-R², printed under the label "R²".
    print(f"\n  {header} (n={len(sample)}, R²={model.prsquared:.3f})")
    for var in regressors:
        pval = model.pvalues[var]
        sig = '***' if pval < 0.01 else '**' if pval < 0.05 else '*' if pval < 0.1 else ''
        print(f"    {var:25s}: coef={model.params[var]:8.4f}  p={pval:.4f}{sig}")
    return sample, model


# Model A: Level only (baseline)
sample_a, m_a = _fit_logit("Model A: Z₁ + GDP + KAOPEN level", ['kaopen_entry_val'])
# Model B: Change only
sample_b, m_b = _fit_logit("Model B: Z₁ + GDP + ΔKAOPEN", ['delta_kaopen'])
# Model C: Both level and change
sample_c, m_c = _fit_logit("Model C: Z₁ + GDP + KAOPEN level + ΔKAOPEN",
                           ['kaopen_entry_val', 'delta_kaopen'])
# Model D: Early change (first 10 years)
sample_d, m_d = _fit_logit("Model D: Z₁ + GDP + ΔKAOPEN(10yr)", ['delta_kaopen_10yr'])

# ═══════════════════════════════════════════════════════════════════════
# 4. Sequence: does opening precede growth acceleration?
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("4. SEQUENCE: DOES OPENING PRECEDE OR FOLLOW GROWTH?")
print("=" * 70)

# For crossers: when did KAOPEN increase most relative to the crossing timeline?
# Split each crosser's transit (9k entry .. 25k exit) at its midpoint and
# compare the KAOPEN change in the first half with that in the second half.

sequence_data = []
for _, row in zone_entrants[zone_entrants['exited_above'] == 1].iterrows():
    iso = row['iso3']
    entry = row['cross_9k']
    exit_yr = row.get('cross_25k', np.nan)
    if pd.isna(exit_yr):
        continue

    c = panel[(panel['iso3'] == iso) & (panel['kaopen'].notna())].sort_values('year')
    transit = c[(c['year'] >= entry) & (c['year'] <= exit_yr)]
    midpoint = entry + (exit_yr - entry) / 2

    first_half = transit[transit['year'] < midpoint]
    second_half = transit[transit['year'] >= midpoint]

    if len(first_half) >= 2 and len(second_half) >= 2:
        # The first observation at or after the midpoint is the split point.
        # It ends the first half and starts the second, so the step across
        # the midpoint is counted exactly once and the two halves sum to the
        # total transit change. The previous version ended the first half at
        # the last pre-midpoint observation and dropped that step entirely.
        kao_start = first_half['kaopen'].iloc[0]
        kao_split = second_half['kaopen'].iloc[0]
        kao_end = second_half['kaopen'].iloc[-1]
        sequence_data.append({
            'iso3': iso,
            'delta_first_half': kao_split - kao_start,
            'delta_second_half': kao_end - kao_split,
            'transit_years': exit_yr - entry,
        })

seq_df = pd.DataFrame(sequence_data)
if len(seq_df) > 0:
    print(f"\nAmong crossers (n={len(seq_df)}):")
    print(f"  ΔKAOPEN first half of transit:  mean={seq_df['delta_first_half'].mean():+.3f}")
    print(f"  ΔKAOPEN second half of transit: mean={seq_df['delta_second_half'].mean():+.3f}")

    # A paired t-test is undefined with a single pair.
    if len(seq_df) >= 2:
        t, p = stats.ttest_rel(seq_df['delta_first_half'], seq_df['delta_second_half'])
        print(f"  Paired t-test (first vs second half): t={t:.3f}, p={p:.4f}")

    more_early = (seq_df['delta_first_half'] > seq_df['delta_second_half']).mean()
    print(f"  Fraction with more opening in first half: {more_early*100:.0f}%")

# ═══════════════════════════════════════════════════════════════════════
# 5. Granger-style: does KAOPEN change predict subsequent growth?
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("5. GRANGER-STYLE: ΔKAOPEN → GROWTH (PANEL, IN ZONE)")
print("=" * 70)

# Build all lags and leads on the FULL panel, then restrict to the zone.
# groupby().shift(k) moves k *rows*. Applied after the zone filter, it paired
# observations that were not k years apart whenever a country left and
# re-entered the zone, and it threw away pre-zone history needed for the
# first lags.
# NOTE(review): this still assumes `panel` has one row per country-year with
# no missing years — confirm.
lagged = panel.sort_values(['iso3', 'year']).copy()

# Past 5-year KAOPEN change: kaopen_t - kaopen_{t-5}
lagged['kaopen_lag5'] = lagged.groupby('iso3')['kaopen'].shift(5)
lagged['delta_kaopen_5yr'] = lagged['kaopen'] - lagged['kaopen_lag5']

# Forward growth: mean rgdp_growth over t+1..t+5. The rolling window is
# backward-looking (t-4..t), so it is shifted 5 rows forward.
lagged['growth_lead5'] = lagged.groupby('iso3')['rgdp_growth'].transform(
    lambda x: x.rolling(5, min_periods=3).mean().shift(-5)
)

# Past growth: mean rgdp_growth over t-4..t, the window that mirrors
# delta_kaopen_5yr. The rolling window is already backward-looking. The
# former extra shift(5) moved it to t-9..t-5, leaving a 5-year gap before
# the t..t+5 KAOPEN change it is meant to predict.
lagged['growth_lag5'] = lagged.groupby('iso3')['rgdp_growth'].transform(
    lambda x: x.rolling(5, min_periods=3).mean()
)

# Forward 5-year KAOPEN change: kaopen_{t+5} - kaopen_t
lagged['delta_kaopen_fwd5'] = lagged.groupby('iso3')['kaopen'].transform(
    lambda x: x.shift(-5) - x
)

zone_panel = lagged[(lagged['gdp_pc_ppp'] >= LOWER) & (lagged['gdp_pc_ppp'] <= UPPER)].copy()

# Regression: does past KAOPEN change predict future growth?
controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'kaopen']

# Past ΔKAOPEN → future growth
xvars = ['delta_kaopen_5yr', 'Z_1'] + controls
subset = zone_panel[['iso3', 'year', 'growth_lead5'] + xvars].dropna()
if len(subset) >= 50:
    gls = PanelGLS()
    gls.fit(subset['growth_lead5'], subset[xvars], subset['iso3'], subset['year'])
    print(f"\n  Past ΔKAOPEN(5yr) → Future growth(5yr) [in zone, N={gls.n_obs}]:")
    for i, var in enumerate(xvars):
        sig = '***' if gls.pvalues[i] < 0.01 else '**' if gls.pvalues[i] < 0.05 else '*' if gls.pvalues[i] < 0.1 else ''
        print(f"    {var:25s}: β={gls.beta[i]:8.4f}  p={gls.pvalues[i]:.4f}{sig}")

# Reverse: past growth → future ΔKAOPEN
xvars2 = ['growth_lag5', 'Z_1'] + controls
subset2 = zone_panel[['iso3', 'year', 'delta_kaopen_fwd5'] + xvars2].dropna()
if len(subset2) >= 50:
    gls2 = PanelGLS()
    gls2.fit(subset2['delta_kaopen_fwd5'], subset2[xvars2], subset2['iso3'], subset2['year'])
    print(f"\n  Past growth(5yr) → Future ΔKAOPEN(5yr) [in zone, N={gls2.n_obs}]:")
    for i, var in enumerate(xvars2):
        sig = '***' if gls2.pvalues[i] < 0.01 else '**' if gls2.pvalues[i] < 0.05 else '*' if gls2.pvalues[i] < 0.1 else ''
        print(f"    {var:25s}: β={gls2.beta[i]:8.4f}  p={gls2.pvalues[i]:.4f}{sig}")

# ═══════════════════════════════════════════════════════════════════════
# 6. Z₁ × ΔKAOPEN interaction: does opening matter more for older countries?
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("6. INTERACTION: Z₁ × ΔKAOPEN IN THE ZONE")
print("=" * 70)

# Interaction of Z_1 with the past 5-year KAOPEN change.
zone_panel['Z_1_x_dkao'] = zone_panel['Z_1'] * zone_panel['delta_kaopen_5yr']

xvars3 = ['Z_1', 'delta_kaopen_5yr', 'Z_1_x_dkao'] + controls

# Same right-hand side, two outcomes: GDP growth, then CA/GDP. Each model is
# estimated only when at least 50 complete rows are available.
for outcome, outcome_label in (('rgdp_growth', 'Growth'), ('ca_gdp', 'CA')):
    model_rows = zone_panel[['iso3', 'year', outcome] + xvars3].dropna()
    if len(model_rows) < 50:
        continue
    fitted = PanelGLS()
    fitted.fit(model_rows[outcome], model_rows[xvars3], model_rows['iso3'], model_rows['year'])
    print(f"\n  {outcome_label} = f(Z₁, ΔKAOPEN, Z₁×ΔKAOPEN) [in zone, N={fitted.n_obs}]:")
    for pos, name in enumerate(xvars3):
        pv = fitted.pvalues[pos]
        stars = next((mark for cut, mark in ((0.01, '***'), (0.05, '**'), (0.1, '*')) if pv < cut), '')
        print(f"    {name:25s}: β={fitted.beta[pos]:8.4f}  p={pv:.4f}{stars}")

# Save results
# Per-country KAOPEN transit measures; only countries with a total
# ΔKAOPEN are written.
bivariate_cols = ['iso3', 'exited_above', 'kaopen_entry_val', 'delta_kaopen',
                  'delta_kaopen_10yr', 'kaopen_annual_change']
results = {
    # reindex (rather than [...]) turns a measure that was never computed
    # into an all-NaN column instead of raising KeyError.
    'bivariate': zone_entrants.reindex(columns=bivariate_cols).dropna(subset=['delta_kaopen']),
}
results['bivariate'].to_csv(
    '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table16_kaopen_change.csv', index=False)

print("\n✓ Phase 4b complete.")
